// SPDX-FileCopyrightText: Copyright 2025-2025 深圳市同心圆网络有限公司
// SPDX-License-Identifier: GPL-3.0-only

package trace_dao

import (
	"bytes"
	"encoding/binary"
	"errors"
	"fmt"
	"math"
	"strings"
	"time"

	"gitcode.com/opendragonfly/df_proto_gen_go.git/trace_api"
	"github.com/dgraph-io/badger/v4"
	"google.golang.org/protobuf/proto"
)

// _CONSUME_TIME_INDEX_NAME_SPACE is the leading key field for consume-time
// index entries that carry a non-empty RootSpanName.
const _CONSUME_TIME_INDEX_NAME_SPACE = "consume"

// _CONSUME_TIME_ALL_INDEX_NAME_SPACE is the leading key field for entries
// whose RootSpanName is empty, i.e. entries not split by root span.
const _CONSUME_TIME_ALL_INDEX_NAME_SPACE = "consumeall"

// ConsumeTimeKey identifies one entry of the consume-time index. Entries are
// grouped by hour bucket, service name, service version and (optionally) root
// span name. Inside such a group they are ordered by ConsumeTime, because
// ToBytes appends ConsumeTime as a big-endian suffix.
//
// The string fields are joined with '\t' when encoded. They must therefore not
// contain a tab, or ParseConsumeTimeKey will reject the key.
type ConsumeTimeKey struct {
	TimeStr        string // hour bucket; List queries it with layout "2006-01-02 15" in local time, e.g. "2025-01-03 15"
	ServiceName    string // service name
	ServiceVersion string // service version
	RootSpanName   string // root span name; empty selects the "consumeall" namespace
	ConsumeTime    uint32 // consume (elapsed) time of the trace; unit presumably ms — TODO confirm. Insert ignores 0
}

// ToBytes encodes the key as
//
//	namespace '\t' TimeStr '\t' ServiceName '\t' ServiceVersion '\t' RootSpanName [ConsumeTime]
//
// namespace is "consumeall" when RootSpanName is empty and "consume"
// otherwise.
//
// With withTime set, ConsumeTime is appended as 4 big-endian bytes. Keys that
// share the textual part therefore sort by ConsumeTime. Without it, the result
// is the textual part alone, which List uses as an iterator prefix.
//
// Writes into a bytes.Buffer never fail, so the error is nil in practice. It
// is kept for interface stability.
func (key *ConsumeTimeKey) ToBytes(withTime bool) ([]byte, error) {
	namespace := _CONSUME_TIME_ALL_INDEX_NAME_SPACE
	if key.RootSpanName != "" {
		namespace = _CONSUME_TIME_INDEX_NAME_SPACE
	}
	fields := []string{namespace, key.TimeStr, key.ServiceName, key.ServiceVersion, key.RootSpanName}

	var out bytes.Buffer
	out.WriteString(strings.Join(fields, "\t"))
	if withTime {
		if err := binary.Write(&out, binary.BigEndian, key.ConsumeTime); err != nil {
			return nil, err
		}
	}
	return out.Bytes(), nil
}

// ParseConsumeTimeKey decodes a key produced by ConsumeTimeKey.ToBytes.
//
// withTime must match the flag given to ToBytes. When it is true, the trailing
// 4 bytes are read as a big-endian ConsumeTime; otherwise ConsumeTime is 0.
// The remaining bytes must split on '\t' into exactly five fields: namespace,
// TimeStr, ServiceName, ServiceVersion and RootSpanName. The namespace field is
// not validated. Any other shape yields a "wrong key buf" error.
func ParseConsumeTimeKey(buf []byte, withTime bool) (*ConsumeTimeKey, error) {
	nameBytes := buf
	var consumeTime uint32
	if withTime {
		if len(buf) < 4 {
			return nil, fmt.Errorf("wrong key buf")
		}
		cut := len(buf) - 4
		nameBytes = buf[:cut]
		consumeTime = binary.BigEndian.Uint32(buf[cut:])
	}

	fields := strings.Split(string(nameBytes), "\t")
	if len(fields) != 5 {
		return nil, fmt.Errorf("wrong key buf")
	}
	// fields[0] is the namespace and is not part of the decoded key.
	timeStr, serviceName, serviceVersion, rootSpanName := fields[1], fields[2], fields[3], fields[4]
	return &ConsumeTimeKey{
		TimeStr:        timeStr,
		ServiceName:    serviceName,
		ServiceVersion: serviceVersion,
		RootSpanName:   rootSpanName,
		ConsumeTime:    consumeTime,
	}, nil
}

// _ConsumeTimeIterWrap wraps a badger iterator over one hour bucket of the
// consume-time index. It buffers at most one decoded entry, so that List can
// merge several buckets by comparing their current heads.
type _ConsumeTimeIterWrap struct {
	// iter is the underlying iterator; Close closes it.
	iter  *badger.Iterator
	// key is the decoded key of the buffered entry, or nil when none is buffered.
	key   *ConsumeTimeKey
	// value is the decoded value of the buffered entry, or nil when none is buffered.
	value *trace_api.TraceIdWithTimeList
}

// newConsumeTimeIterWrap wraps iter with nothing buffered. The first Prepare
// call loads the entry at the iterator's current position.
func newConsumeTimeIterWrap(iter *badger.Iterator) *_ConsumeTimeIterWrap {
	wrap := new(_ConsumeTimeIterWrap)
	wrap.iter = iter
	return wrap
}

// Prepare ensures the wrapper holds one decoded entry.
//
// If an entry is already buffered, Prepare returns immediately. Otherwise it
// walks forward from the iterator's position and skips deleted or expired
// items. For the first live item it decodes the key and the
// TraceIdWithTimeList value, buffers both, and advances the iterator past
// that item.
//
// When the iterator is exhausted nothing is buffered, so GetKey and GetValue
// return nil. Decode errors are returned unchanged.
func (wrap *_ConsumeTimeIterWrap) Prepare() error {
	if wrap.key != nil && wrap.value != nil {
		return nil
	}

	it := wrap.iter
	for it.Valid() {
		item := it.Item()
		if item.IsDeletedOrExpired() {
			it.Next()
			continue
		}
		parsedKey, err := ParseConsumeTimeKey(item.KeyCopy(nil), true)
		if err != nil {
			return err
		}
		raw, err := item.ValueCopy(nil)
		if err != nil {
			return err
		}
		list := new(trace_api.TraceIdWithTimeList)
		if err := proto.Unmarshal(raw, list); err != nil {
			return err
		}
		wrap.key, wrap.value = parsedKey, list
		it.Next()
		return nil
	}
	return nil
}

// GetKey returns the key buffered by the latest successful Prepare, or nil
// when no entry is currently buffered.
func (wrap *_ConsumeTimeIterWrap) GetKey() *ConsumeTimeKey {
	buffered := wrap.key
	return buffered
}

// GetValue returns the trace list buffered by the latest successful Prepare,
// or nil when no entry is currently buffered.
func (wrap *_ConsumeTimeIterWrap) GetValue() *trace_api.TraceIdWithTimeList {
	buffered := wrap.value
	return buffered
}

// ClearKeyValue drops the buffered entry, so the next Prepare call loads the
// following one.
func (wrap *_ConsumeTimeIterWrap) ClearKeyValue() {
	wrap.key, wrap.value = nil, nil
}

// Close closes the underlying badger iterator.
func (wrap *_ConsumeTimeIterWrap) Close() {
	it := wrap.iter
	it.Close()
}

// _ConsumeTimeIndex maintains the consume-time secondary index. For each
// ConsumeTimeKey it stores a TraceIdWithTimeList that holds the most recently
// inserted traces (at most 100, see Insert). List reads the index in
// descending ConsumeTime order.
type _ConsumeTimeIndex struct{}

// Insert records traceWithTime under consumeKey in the consume-time index.
//
// The value stored for a key is a TraceIdWithTimeList. The new trace is
// appended to the existing live list, if there is one; a missing, deleted or
// expired entry starts a fresh list. Only the newest maxTracesPerConsumeKey
// traces are kept. Every write (re)sets the entry's TTL to keepTraceDay days.
// Keys with ConsumeTime == 0 are not indexed.
func (index *_ConsumeTimeIndex) Insert(txn *badger.Txn, consumeKey *ConsumeTimeKey, traceWithTime *trace_api.TraceIdWithTime, keepTraceDay uint) error {
	// Cap on the number of traces kept per key; the oldest are dropped first.
	const maxTracesPerConsumeKey = 100

	if consumeKey.ConsumeTime == 0 {
		return nil
	}
	key, err := consumeKey.ToBytes(true)
	if err != nil {
		return err
	}

	traceItemList := &trace_api.TraceIdWithTimeList{}
	item, err := txn.Get(key)
	switch {
	case errors.Is(err, badger.ErrKeyNotFound):
		// First trace for this key: start from an empty list.
	case err != nil:
		return err
	case item.IsDeletedOrExpired():
		// Stale entry: overwrite it instead of extending it.
	default:
		value, err := item.ValueCopy(nil)
		if err != nil {
			return err
		}
		if err := proto.Unmarshal(value, traceItemList); err != nil {
			return err
		}
	}

	traceItemList.ItemList = append(traceItemList.ItemList, traceWithTime)
	if l := len(traceItemList.ItemList); l > maxTracesPerConsumeKey {
		traceItemList.ItemList = traceItemList.ItemList[l-maxTracesPerConsumeKey:]
	}
	newValue, err := proto.Marshal(traceItemList)
	if err != nil {
		return err
	}
	ttl := time.Duration(keepTraceDay) * 24 * time.Hour
	return txn.SetEntry(badger.NewEntry(key, newValue).WithTTL(ttl))
}

// List returns up to limit trace IDs recorded for serviceInfo whose
// StartTimeStamp lies in [fromTime, toTime]. fromTime and toTime are Unix
// milliseconds and both bounds are inclusive.
//
// Results are ordered by ConsumeTime, descending, across every hour bucket
// the range touches. When filterByRootSpanName is false, rootSpanName is
// ignored and the "consumeall" namespace is scanned.
func (index *_ConsumeTimeIndex) List(txn *badger.Txn, serviceInfo *trace_api.ServiceInfo,
	filterByRootSpanName bool, rootSpanName string,
	fromTime, toTime int64, limit uint32) ([]string, error) {
	const hourMillis = int64(3600 * 1000)

	if !filterByRootSpanName {
		rootSpanName = ""
	}

	// Collect the hour buckets covered by [fromTime, toTime] in chronological
	// order, without duplicates. Using a slice instead of ranging over a map
	// keeps the merge below deterministic when heads share a ConsumeTime.
	// NOTE(review): one iterator is opened per bucket, so very wide ranges are
	// expensive.
	timeStrList := []string{}
	seenTimeStr := map[string]bool{}
	addTimeStr := func(ms int64) {
		timeStr := time.UnixMilli(ms).Format("2006-01-02 15")
		if !seenTimeStr[timeStr] {
			seenTimeStr[timeStr] = true
			timeStrList = append(timeStrList, timeStr)
		}
	}
	for curTime := fromTime; curTime <= toTime; curTime += hourMillis {
		addTimeStr(curTime)
		// Near MaxInt64, curTime+hourMillis would wrap negative and loop
		// (practically) forever; stop before that happens.
		if curTime > math.MaxInt64-hourMillis {
			break
		}
	}
	addTimeStr(toTime)

	iterWrapList := make([]*_ConsumeTimeIterWrap, 0, len(timeStrList))
	defer func() {
		for _, iterWrap := range iterWrapList {
			iterWrap.Close()
		}
	}()

	for _, timeStr := range timeStrList {
		consumeKey := &ConsumeTimeKey{
			TimeStr:        timeStr,
			ServiceName:    serviceInfo.ServiceName,
			ServiceVersion: serviceInfo.ServiceVersion,
			RootSpanName:   rootSpanName,
			ConsumeTime:    0,
		}
		keyPrefix, err := consumeKey.ToBytes(false)
		if err != nil {
			return nil, err
		}
		// In reverse mode, Seek lands on the largest key <= the seek key.
		// Seeking to the maximum ConsumeTime therefore starts at this
		// bucket's largest ConsumeTime.
		consumeKey.ConsumeTime = math.MaxUint32
		seekKey, err := consumeKey.ToBytes(true)
		if err != nil {
			return nil, err
		}
		iter := txn.NewIterator(badger.IteratorOptions{
			PrefetchValues: true,
			PrefetchSize:   100,
			Reverse:        true,
			AllVersions:    false,
			Prefix:         keyPrefix,
		})
		// Register immediately so the deferred cleanup always closes it.
		iterWrapList = append(iterWrapList, newConsumeTimeIterWrap(iter))
		iter.Seek(seekKey)
	}

	traceIdList := []string{}
	for {
		// Pick the bucket whose buffered head has the largest ConsumeTime.
		// Ties go to the earliest bucket.
		var maxIterWrap *_ConsumeTimeIterWrap
		for _, iterWrap := range iterWrapList {
			if err := iterWrap.Prepare(); err != nil {
				return nil, err
			}
			key := iterWrap.GetKey()
			if key == nil {
				continue // bucket exhausted
			}
			if maxIterWrap == nil || key.ConsumeTime > maxIterWrap.GetKey().ConsumeTime {
				maxIterWrap = iterWrap
			}
		}
		if maxIterWrap == nil {
			return traceIdList, nil
		}
		headKey := maxIterWrap.GetKey()
		headValue := maxIterWrap.GetValue()
		maxIterWrap.ClearKeyValue()

		// The prefix is matched byte-wise and ends at RootSpanName with no
		// terminator. A prefix for "foo" therefore also matches keys of
		// "foobar", or of a shorter name whose binary ConsumeTime suffix
		// happens to continue the match. Skip entries for other root spans.
		if headKey.RootSpanName != rootSpanName {
			continue
		}
		for _, item := range headValue.ItemList {
			if item.StartTimeStamp < fromTime || item.StartTimeStamp > toTime {
				continue
			}
			traceIdList = append(traceIdList, item.TraceId)
			// NOTE(review): the check runs after the append, so limit == 0
			// still yields one ID. Preserved as-is for existing callers.
			if uint64(len(traceIdList)) >= uint64(limit) {
				return traceIdList, nil
			}
		}
	}
}
